Imports¶
In [1]:
import os
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
if not 'have_changed_cwd' in globals():
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")
have_changed_cwd = True
from src.models.classification.classifier import ImageClassifier
from src.dataset.cifar import CIFAR10DataModule
from src.models.autoencoding.autoencoder import Autoencoder
from src.utils import plot_image_grid
chkpt_path = "classifiers/resnet9-cifar10/2024-12-27/19-57-36/checkpoints/last.ckpt"
# chkpt_path = "classifiers/linear-cifar10/2024-12-27/20-41-01/checkpoints/last.ckpt"
classifier = ImageClassifier.load_from_checkpoint(chkpt_path)
chkpt_path = "autoencoders/convae-cifar10/2024-12-27/20-11-11/checkpoints/last.ckpt"
autoencoder = Autoencoder.load_from_checkpoint(chkpt_path)
datamodule = CIFAR10DataModule.get_default_dataset("cifar10", samples_per_class=100)
x, y = next(iter(datamodule.train_dataloader()))
N = 9
x = x[:N]
y = y[:N]
Working directory: /Users/mat/Desktop/Files/Code/Generative-Data-Augmentation Files already downloaded and verified Files already downloaded and verified
Classifier Predictions¶
In [58]:
x, y = next(iter(datamodule.val_dataloader()))
N = 16
x = x[:N]
y = y[:N]
In [15]:
from src.eval.classification import inference_with_classifier
out = inference_with_classifier(classifier, datamodule.test_dataloader(), device='mps')
Doing inference on 16 batches on mps: 100%|██████████| 16/16 [00:00<00:00, 41.14it/s]
In [16]:
from sklearn.metrics import classification_report
print(classification_report(out['target'], out['pred'], target_names=datamodule.class_names))
precision recall f1-score support
airplane 0.89 0.74 0.81 100
automobile 0.94 0.89 0.91 100
bird 0.78 0.90 0.83 100
cat 0.78 0.81 0.79 100
deer 0.87 0.87 0.87 100
dog 0.82 0.81 0.81 100
frog 0.95 0.80 0.87 100
horse 0.85 0.90 0.87 100
ship 0.91 0.86 0.89 100
truck 0.80 0.95 0.87 100
accuracy 0.85 1000
macro avg 0.86 0.85 0.85 1000
weighted avg 0.86 0.85 0.85 1000
In [8]:
from src.eval.classification import collect_misclassified, plot_image_grid
from matplotlib import pyplot as plt
misclassified_images, true_labels, pred_labels, pred_logits = collect_misclassified(classifier, datamodule.val_dataloader(), device='mps', num_samples=9)
In [9]:
plot_image_grid(misclassified_images, true_labels, pred_logits, class_names=datamodule.class_names)
plt.show()
Autoencoder predictions¶
In [ ]:
x, y = next(iter(datamodule.val_dataloader()))
N = 16
x = x[:N]
y = y[:N]
In [10]:
from src.eval.autoencoding import inference_with_autoencoder, collect_high_error_reconstructions, plot_reconstruction_pairs
out = inference_with_autoencoder(autoencoder, datamodule.test_dataloader(), device='mps')
In [11]:
plt.hist(out['mse_losses'], bins=100)
plt.title(f"Reconstruction Error Histogram (MSE) - mean: {out['mse_losses'].mean():.4f}")
plt.grid()
plt.show()
In [12]:
original_inputs, reconstructions, latent_codes, mse_values = collect_high_error_reconstructions(autoencoder, datamodule.test_dataloader(), device='mps', threshold=0.015, num_samples=9)
In [13]:
plot_reconstruction_pairs(original_inputs, reconstructions, num_pairs=5, figsize=(8, 8), mse_losses=mse_values)
plt.show()
Gradients¶
In [ ]:
x, y = next(iter(datamodule.val_dataloader()))
N = 16
x = x[:N]
y = y[:N]
Pixel Space¶
In [ ]:
num_steps = 300
lr = 0.001
weight_decay = 0.01
optimizer_cls = torch.optim.AdamW
In [ ]:
from src.input_gradients import compute_proba_grads_wrt_data, plot_grads_wrt_data
grads, _ = compute_proba_grads_wrt_data(classifier, x.clone(), list(range(10)), device='mps')
plot_grads_wrt_data(x, grads, list(range(10)), datamodule.class_names)
plt.show()
In [22]:
from src.input_gradients import optimize_proba_wrt_data, plot_optimal_images
config = {
"num_steps": num_steps,
"optimizer_cls": optimizer_cls,
"save_k_intermediate_imgs": 10,
"logit_transform": None,
"optimizer_kwargs": {
"lr": lr,
"weight_decay": weight_decay,
},
}
out = optimize_proba_wrt_data(
classifier,
x.clone(),
list(range(10)),
config,
device='mps',
)
plot_optimal_images(out, datamodule.class_names)
plt.show()
In [23]:
from src.input_gradients import plot_optimization_trajectory_fixed_sample, plot_optimization_trajectory_fixed_target
plot_optimization_trajectory_fixed_target(out, 2, datamodule.class_names)
plt.show()
In [24]:
plot_optimization_trajectory_fixed_sample(out, 0, datamodule.class_names, targets=[0, 1, 2], figsizes=((15, 8), (15, 12)))
plt.show()
Latent Space¶
In [38]:
num_steps = 500
lr = 0.001
weight_decay = 0.03
optimizer_cls = torch.optim.AdamW
In [ ]:
from src.input_gradients import compute_proba_grads_wrt_data, plot_grads_wrt_data
grads, finite_diffs = compute_proba_grads_wrt_data(classifier, x.clone(), list(range(10)), device='mps', autoencoder=autoencoder, epsilon=0.0001)
plot_grads_wrt_data(x, finite_diffs, list(range(10)), datamodule.class_names)
plt.show()
In [40]:
from src.input_gradients import optimize_proba_wrt_data, plot_optimal_images
config = {
"num_steps": num_steps,
"optimizer_cls": optimizer_cls,
"save_k_intermediate_imgs": 10,
"logit_transform": None,
"optimizer_kwargs": {
"lr": lr,
"weight_decay": weight_decay,
},
}
out = optimize_proba_wrt_data(
classifier,
x.clone(),
list(range(10)),
config,
device='mps',
autoencoder=autoencoder,
)
plot_optimal_images(out, datamodule.class_names)
plt.show()
In [41]:
from src.input_gradients import plot_optimization_trajectory_fixed_sample, plot_optimization_trajectory_fixed_target
plot_optimization_trajectory_fixed_target(out, 2, datamodule.class_names)
plt.show()
In [42]:
plot_optimization_trajectory_fixed_sample(out, 0, datamodule.class_names, targets=[0, 1, 2], figsizes=((15, 8), (15, 12)))
plt.show()
In [ ]:
torch.save(out, "optimal_images.pth")
In [46]:
out_loaded = torch.load("optimal_images.pth")
/var/folders/jy/x5558th97mjgtzp6f33ryf840000gn/T/ipykernel_59460/595520248.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
out_loaded = torch.load("optimal_images.pth")
Flatness¶
In [2]:
x, y = next(iter(datamodule.val_dataloader()))
N = 3
x = x[:N]
y = y[:N]
In [12]:
from src.eval.flatness import compute_local_energy, plot_local_energy
local_energy = compute_local_energy(
classifier,
datamodule.test_dataloader(),
1000,
None,
np.linspace(0.0, 0.5, 6),
10,
'mps'
)
Computing Local Energy for various noise levels: 100%|██████████| 6/6 [00:17<00:00, 2.84s/it]
In [ ]:
fig = plot_local_energy(local_energy)
plt.show()
In [3]:
from src.eval.flatness import compute_input_flatness, plot_inputs_flatness
results, gt = compute_input_flatness(
classifier,
[(x, y)],
2,
np.linspace(0.0, 1.0, 11),
100,
'mps'
)
Processing batches: 0%| | 0/1 [00:03<?, ?it/s]
In [4]:
plot_inputs_flatness(results, gt, -1, class_names=datamodule.class_names, filter_misclassified=True, print_summary=False, plot_global_average=True, plot_individual_samples=True)
plt.show()
In [5]:
from src.eval.flatness import visualize_input_noise
visualize_input_noise(x)
plt.show()
In [6]:
from src.eval.input_gradients import optimize_proba_wrt_data, plot_optimal_images
config = {
"num_steps": 300,
"optimizer_cls": torch.optim.AdamW,
"save_k_intermediate_imgs": 10,
"logit_transform": None,
"perturb_weights": True,
"stddev": 0.2,
"weights_sample_freq": 10,
"optimizer_kwargs": {
"lr": 0.001,
"weight_decay": 0.01,
},
}
out = optimize_proba_wrt_data(
classifier,
x.clone(),
list(range(10)),
config,
device='mps',
)
plot_optimal_images(out, datamodule.class_names)
plt.show()
Optimizing target 0: 100%|██████████| 300/300 [00:02<00:00, 115.66it/s] Optimizing target 1: 100%|██████████| 300/300 [00:02<00:00, 142.51it/s] Optimizing target 2: 100%|██████████| 300/300 [00:02<00:00, 138.18it/s] Optimizing target 3: 100%|██████████| 300/300 [00:02<00:00, 142.88it/s] Optimizing target 4: 100%|██████████| 300/300 [00:02<00:00, 139.71it/s] Optimizing target 5: 100%|██████████| 300/300 [00:02<00:00, 142.34it/s] Optimizing target 6: 100%|██████████| 300/300 [00:02<00:00, 128.53it/s] Optimizing target 7: 100%|██████████| 300/300 [00:02<00:00, 131.86it/s] Optimizing target 8: 100%|██████████| 300/300 [00:02<00:00, 140.45it/s] Optimizing target 9: 100%|██████████| 300/300 [00:02<00:00, 138.63it/s]
Genuine vs Adversarial examples¶
In [2]:
x, y = next(iter(datamodule.val_dataloader()))
N = 16
x = x[:N]
y = y[:N]
In [5]:
meanss = []
stdss = []
labels = []
In [3]:
plot_image_grid(x, y, datamodule.class_names, figsize=(4, 4))
plt.show()
In [4]:
from src.eval.flatness import compute_input_flatness, plot_inputs_flatness
results, gt = compute_input_flatness(
classifier,
datamodule.test_dataloader(),
1000,
np.linspace(0.0, 1.0, 11),
9,
'mps'
)
Processing batches: 94%|█████████▍| 15/16 [00:22<00:01, 1.51s/it]
In [6]:
fig1, fig2, fig3, means, stds = plot_inputs_flatness(results, gt, -1, class_names=datamodule.class_names, filter_misclassified=True, plot_global_average=True)
plt.show()
meanss.append(means)
stdss.append(stds)
labels.append("Genuine Samples")
Adversarial pixel space¶
In [7]:
from src.eval.input_gradients import optimize_proba_wrt_data, plot_optimal_images
config = {
"num_steps": 300,
"optimizer_cls": torch.optim.AdamW,
"save_k_intermediate_imgs": 10,
"logit_transform": None,
"perturb_weights": False,
"stddev": 0.2,
"weights_sample_freq": 10,
"optimizer_kwargs": {
"lr": 0.001,
"weight_decay": 0.01,
},
}
out = optimize_proba_wrt_data(
classifier,
x.clone(),
list(range(10)),
config,
device='mps',
)
Optimizing target 0: 0%| | 0/300 [00:00<?, ?it/s]
Optimizing target 0: 100%|██████████| 300/300 [00:04<00:00, 64.61it/s] Optimizing target 1: 100%|██████████| 300/300 [00:04<00:00, 67.04it/s] Optimizing target 2: 100%|██████████| 300/300 [00:04<00:00, 67.97it/s] Optimizing target 3: 100%|██████████| 300/300 [00:04<00:00, 67.38it/s] Optimizing target 4: 100%|██████████| 300/300 [00:04<00:00, 67.71it/s] Optimizing target 5: 100%|██████████| 300/300 [00:04<00:00, 67.63it/s] Optimizing target 6: 100%|██████████| 300/300 [00:04<00:00, 67.06it/s] Optimizing target 7: 100%|██████████| 300/300 [00:04<00:00, 64.29it/s] Optimizing target 8: 100%|██████████| 300/300 [00:04<00:00, 67.96it/s] Optimizing target 9: 100%|██████████| 300/300 [00:04<00:00, 67.17it/s]
In [8]:
adv_images = torch.cat([out["final_imgs"][i] for i in range(10)], dim=0)
adv_labels = torch.cat([torch.tensor([i for _ in range(N)]) for i in range(10)], dim=0)
adv_images.shape, adv_labels.shape
Out[8]:
(torch.Size([160, 3, 32, 32]), torch.Size([160]))
In [9]:
adv_results, adv_gt = compute_input_flatness(
classifier,
[(adv_images, adv_labels)],
None,
np.linspace(0.0, 1.0, 11),
9,
'mps'
)
Processing batches: 100%|██████████| 1/1 [00:03<00:00, 3.37s/it]
In [10]:
fig1, fig2, fig3, means, stds = plot_inputs_flatness(adv_results, adv_gt, -1, class_names=datamodule.class_names, filter_misclassified=True, plot_global_average=True)
plt.show()
meanss.append(means)
stdss.append(stds)
labels.append("Adversarial Pixels")
Adversarial latent space¶
In [11]:
from src.eval.input_gradients import optimize_proba_wrt_data, plot_optimal_images
config = {
"num_steps": 600,
"optimizer_cls": torch.optim.AdamW,
"save_k_intermediate_imgs": 10,
"logit_transform": None,
"perturb_weights": False,
"stddev": 0.2,
"weights_sample_freq": 10,
"optimizer_kwargs": {
"lr": 0.001,
"weight_decay": 0.01,
},
}
out = optimize_proba_wrt_data(
classifier,
x.clone(),
list(range(10)),
config,
device='mps',
autoencoder=autoencoder,
)
Optimizing target 0: 100%|██████████| 600/600 [00:46<00:00, 13.04it/s] Optimizing target 1: 100%|██████████| 600/600 [00:46<00:00, 13.03it/s] Optimizing target 2: 100%|██████████| 600/600 [00:45<00:00, 13.17it/s] Optimizing target 3: 100%|██████████| 600/600 [00:45<00:00, 13.27it/s] Optimizing target 4: 100%|██████████| 600/600 [00:45<00:00, 13.27it/s] Optimizing target 5: 100%|██████████| 600/600 [00:45<00:00, 13.27it/s] Optimizing target 6: 100%|██████████| 600/600 [00:46<00:00, 13.02it/s] Optimizing target 7: 100%|██████████| 600/600 [00:45<00:00, 13.16it/s] Optimizing target 8: 100%|██████████| 600/600 [00:45<00:00, 13.11it/s] Optimizing target 9: 100%|██████████| 600/600 [00:45<00:00, 13.05it/s]
In [12]:
adv_images = torch.cat([out["final_imgs"][i] for i in range(10)], dim=0)
adv_labels = torch.cat([torch.tensor([i for _ in range(N)]) for i in range(10)], dim=0)
adv_images.shape, adv_labels.shape
Out[12]:
(torch.Size([160, 3, 32, 32]), torch.Size([160]))
In [13]:
adv_results, adv_gt = compute_input_flatness(
classifier,
[(adv_images, adv_labels)],
None,
np.linspace(0.0, 1.0, 11),
9,
'mps'
)
Processing batches: 100%|██████████| 1/1 [00:03<00:00, 3.57s/it]
In [14]:
fig1, fig2, fig3, means, stds = plot_inputs_flatness(adv_results, adv_gt, -1, class_names=datamodule.class_names, filter_misclassified=True, plot_global_average=True)
plt.show()
meanss.append(means)
stdss.append(stds)
labels.append("Adversarial Latent")
With classifier noise¶
In [15]:
from src.eval.input_gradients import optimize_proba_wrt_data, plot_optimal_images
config = {
"num_steps": 300,
"optimizer_cls": torch.optim.AdamW,
"save_k_intermediate_imgs": 10,
"logit_transform": None,
"perturb_weights": True,
"stddev": 0.2,
"weights_sample_freq": 10,
"optimizer_kwargs": {
"lr": 0.001,
"weight_decay": 0.01,
},
}
out = optimize_proba_wrt_data(
classifier,
x.clone(),
list(range(10)),
config,
device='mps',
)
Optimizing target 0: 100%|██████████| 300/300 [00:04<00:00, 61.80it/s] Optimizing target 1: 100%|██████████| 300/300 [00:04<00:00, 63.31it/s] Optimizing target 2: 100%|██████████| 300/300 [00:04<00:00, 65.58it/s] Optimizing target 3: 100%|██████████| 300/300 [00:04<00:00, 65.69it/s] Optimizing target 4: 100%|██████████| 300/300 [00:04<00:00, 65.79it/s] Optimizing target 5: 100%|██████████| 300/300 [00:04<00:00, 65.49it/s] Optimizing target 6: 100%|██████████| 300/300 [00:04<00:00, 65.63it/s] Optimizing target 7: 100%|██████████| 300/300 [00:04<00:00, 65.77it/s] Optimizing target 8: 100%|██████████| 300/300 [00:04<00:00, 63.50it/s] Optimizing target 9: 100%|██████████| 300/300 [00:04<00:00, 65.63it/s]
In [16]:
adv_images = torch.cat([out["final_imgs"][i] for i in range(10)], dim=0)
adv_labels = torch.cat([torch.tensor([i for _ in range(N)]) for i in range(10)], dim=0)
adv_images.shape, adv_labels.shape
Out[16]:
(torch.Size([160, 3, 32, 32]), torch.Size([160]))
In [17]:
adv_results, adv_gt = compute_input_flatness(
classifier,
[(adv_images, adv_labels)],
None,
np.linspace(0.0, 1.0, 11),
9,
'mps'
)
Processing batches: 100%|██████████| 1/1 [00:03<00:00, 3.43s/it]
In [18]:
fig1, fig2, fig3, means, stds = plot_inputs_flatness(adv_results, adv_gt, -1, class_names=datamodule.class_names, filter_misclassified=True, plot_global_average=True)
plt.show()
meanss.append(means)
stdss.append(stds)
labels.append("Adversarial Pixel noisy weights")
In [25]:
fig, ax = plt.subplots(figsize=(8, 6))
colors = ["red", "blue", "yellow", "black"]
for means, stds, label, color in zip(meanss, stdss, labels, colors):
ax.errorbar(np.linspace(0.0, 1.0, 11), means, yerr=stds, label=label, color=color)
plt.xticks(np.linspace(0.0, 1.0, 11))
plt.yticks(np.linspace(0.0, 1.0, 11))
plt.xlabel("Noise Level")
plt.ylabel("probability")
plt.title("Flatness of the input landscape")
plt.grid()
plt.legend()
plt.show()